{
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n",
        "\n",
        "# Licensed to the Apache Software Foundation (ASF) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The ASF licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        "# \"License\"); you may not use this file except in compliance\n",
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License"
      ],
      "metadata": {
        "cellView": "form",
        "id": "faayYQYrQzY3"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JjAt1GesQ9sg"
      },
      "source": [
        "# Use RunInference in Apache Beam\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_pytorch_tensorflow_sklearn.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_pytorch_tensorflow_sklearn.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A8xNRyZMW1yK"
      },
      "source": [
        "You can use Apache Beam versions 2.40.0 and later with the [RunInference API](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.RunInference) for local and remote inference with batch and streaming pipelines.\n",
        "The RunInference API uses Apache Beam concepts, such as the `BatchElements` transform and the `Shared` class, so that you can use models in your pipelines through transforms that are optimized for machine learning inference.\n",
        "\n",
        "For more information about the RunInference API, see [About Beam ML](https://beam.apache.org/documentation/ml/about-ml) in the Apache Beam documentation."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A8xNRyZMW1yK"
      },
      "source": [
        "This example demonstrates how to use the RunInference API with three popular ML frameworks: PyTorch, TensorFlow, and scikit-learn. Each of the three pipelines uses a text classification model to generate predictions.\n",
        "\n",
        "Follow these steps to build a pipeline:\n",
        "* Read the input text.\n",
        "* If needed, preprocess the text.\n",
        "* Run inference with the PyTorch, TensorFlow, or scikit-learn model.\n",
        "* If needed, postprocess the output."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CTtBTpsHZFCk"
      },
      "source": [
        "## RunInference with a PyTorch model\n",
        "\n",
        "This section demonstrates how to use the RunInference API with a PyTorch model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5kkjbcIzZIf6"
      },
      "source": [
        "### Install dependencies\n",
        "\n",
        "First, download and install the dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "MRASwRTxY-2u",
        "outputId": "28760c59-c4dc-4486-dbd2-e7ac2c92c3b8"
      },
      "outputs": [],
      "source": [
        "!pip install --upgrade pip\n",
        "# Quote the requirement so that the shell doesn't treat >= as an output redirect.\n",
        "!pip install 'apache_beam[interactive,gcp]>=2.40.0'\n",
        "!pip install transformers\n",
        "!pip install google-api-core==1.32"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ObRPUrlEbjHj"
      },
      "source": [
        "### Install the model\n",
        "\n",
        "This example uses a pretrained text classification model, [distilbert-base-uncased-finetuned-sst-2-english](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english?text=I+like+you.+I+love+you). This model is a checkpoint of `DistilBERT-base-uncased`, fine-tuned on the SST-2 dataset.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "vfDyy4WNQaJM",
        "outputId": "75683116-f415-4956-f44c-baa953c564e1"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Error: Failed to call git rev-parse --git-dir --show-toplevel: \"fatal: not a git repository (or any of the parent directories): .git\\n\"\n",
            "Git LFS initialized.\n",
            "fatal: destination path 'distilbert-base-uncased-finetuned-sst-2-english' already exists and is not an empty directory.\n",
            "'=2.40.0'   distilbert-base-uncased-finetuned-sst-2-english   sample_data\n"
          ]
        }
      ],
      "source": [
        "! git lfs install\n",
        "! git clone https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english\n",
        "! ls"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vA1UmbFRb5C-"
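      },
      "source": [
        "### Install helper functions\n",
        "\n",
        "The pipeline uses three helper classes: a model wrapper that returns one output dictionary per example, a `Tokenize` DoFn that converts the input text into padded token tensors, and a `PostProcessor` DoFn that converts the output logits into softmax probabilities."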
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c4ZwN8wsbvgK"
      },
      "outputs": [],
      "source": [
        "from collections import defaultdict\n",
        "\n",
        "import torch\n",
        "from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertConfig\n",
        "\n",
        "import apache_beam as beam\n",
        "from apache_beam.ml.inference import RunInference\n",
        "from apache_beam.ml.inference.base import PredictionResult, KeyedModelHandler\n",
        "from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor\n",
        "\n",
        "\n",
        "class HuggingFaceStripBatchingWrapper(DistilBertForSequenceClassification):\n",
        "  \"\"\"Wrapper around the Hugging Face model. RunInference requires a batch\n",
        "  as a list of dicts (one dict per example) instead of a dict of lists.\n",
        "  For another workaround, which disables batching instead, see:\n",
        "  https://github.com/apache/beam/blob/master/sdks/python/apache_beam/examples/inference/pytorch_language_modeling.py\"\"\"\n",
        "  def forward(self, **kwargs):\n",
        "    output = super().forward(**kwargs)\n",
        "    return [dict(zip(output, v)) for v in zip(*output.values())]\n",
        "\n",
        "\n",
        "\n",
        "class Tokenize(beam.DoFn):\n",
        "  def __init__(self, model_name: str):\n",
        "    self._model_name = model_name\n",
        "\n",
        "  def setup(self):\n",
        "    self._tokenizer = DistilBertTokenizer.from_pretrained(self._model_name)\n",
        "\n",
        "  def process(self, text_input: str):\n",
        "    # Pad the token tensors to the maximum length so that all of the tensors\n",
        "    # have the same length and the RunInference API can stack them. Normally, you\n",
        "    # would batch first and then tokenize the batch, padding each tensor to the\n",
        "    # maximum length in the batch.\n",
        "    # See: https://beam.apache.org/documentation/ml/about-ml/#unable-to-batch-tensor-elements\n",
        "    tokens = self._tokenizer(text_input, return_tensors='pt', padding='max_length', max_length=512)\n",
        "    # Squeeze the tensors because tokenization adds a batch dimension of size 1\n",
        "    # when you tokenize one element at a time.\n",
        "    tokens = {key: torch.squeeze(val) for key, val in tokens.items()}\n",
        "    return [(text_input, tokens)]\n",
        "\n",
        "class PostProcessor(beam.DoFn):\n",
        "  def process(self, tuple_):\n",
        "    text_input, prediction_result = tuple_\n",
        "    softmax = torch.nn.Softmax(dim=-1)(prediction_result.inference['logits']).detach().numpy()\n",
        "    return [{\"input\": text_input, \"softmax\": softmax}]"
      ]
    },
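    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To illustrate the reshaping that `HuggingFaceStripBatchingWrapper.forward` performs, the following minimal sketch applies the same expression to a plain dictionary. The `batched_output` dictionary, its keys, and its values are hypothetical stand-ins for a real model output, and plain lists stand in for tensors. The sketch doesn't need a model or a pipeline."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical batched output: one key per output field,\n",
        "# one entry per example in the batch (dict of lists).\n",
        "batched_output = {\n",
        "    'logits': [[2.0, -1.0], [-0.5, 1.5]],\n",
        "    'score': [0.9, 0.1],\n",
        "}\n",
        "\n",
        "# zip(*values) regroups the entries per example, and zip(keys, entries)\n",
        "# pairs each entry with its key, which yields a list of dicts.\n",
        "per_example = [dict(zip(batched_output, v)) for v in zip(*batched_output.values())]\n",
        "\n",
        "for example in per_example:\n",
        "  print(example)"
      ]
    },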
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WYYbQTMWctkW"
      },
      "source": [
        "### Run the pipeline\n",
        "\n",
        "This section demonstrates how to create and run the RunInference pipeline."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lLb8D2n2n09n"
      },
      "outputs": [],
      "source": [
        "inputs = [\n",
        "    \"This is the worst food I have ever eaten\",\n",
        "    \"In my soul and in my heart, I’m convinced I’m wrong!\",\n",
        "    \"Be with me always—take any form—drive me mad! only do not leave me in this abyss, where I cannot find you!\",\n",
        "    \"Do I want to live? Would you like to live with your soul in the grave?\",\n",
        "    \"Honest people don’t hide their deeds.\",\n",
        "    \"Nelly, I am Heathcliff!  He’s always, always in my mind: not as a pleasure, any more than I am always a pleasure to myself, but as my own being.\",\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 269
        },
        "id": "TDmMARxGb751",
        "outputId": "437e168a-b4c5-463b-ce5f-09a8cb8d8191"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:10: FutureWarning: PytorchModelHandlerKeyedTensor is experimental. No backwards-compatibility guarantees.\n",
            "  # Remove the CWD from sys.path while we load stuff.\n",
            "WARNING:apache_beam.runners.interactive.interactive_environment:Dependencies required for Interactive Beam PCollection visualization are not available, please use: `pip install apache-beam[interactive]` to install necessary dependencies to enable all data visualization features.\n"
          ]
        },
        {
          "data": {
            "application/javascript": "\n        if (typeof window.interactive_beam_jquery == 'undefined') {\n          var jqueryScript = document.createElement('script');\n          jqueryScript.src = 'https://code.jquery.com/jquery-3.4.1.slim.min.js';\n          jqueryScript.type = 'text/javascript';\n          jqueryScript.onload = function() {\n            var datatableScript = document.createElement('script');\n            datatableScript.src = 'https://cdn.datatables.net/1.10.20/js/jquery.dataTables.min.js';\n            datatableScript.type = 'text/javascript';\n            datatableScript.onload = function() {\n              window.interactive_beam_jquery = jQuery.noConflict(true);\n              window.interactive_beam_jquery(document).ready(function($){\n                \n              });\n            }\n            document.head.appendChild(datatableScript);\n          };\n          document.head.appendChild(jqueryScript);\n        } else {\n          window.interactive_beam_jquery(document).ready(function($){\n            \n          });\n        }"
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/dill/_dill.py:472: FutureWarning: PytorchModelHandlerKeyedTensor is experimental. No backwards-compatibility guarantees.\n",
            "  obj = StockUnpickler.load(self)\n",
            "/usr/local/lib/python3.7/dist-packages/dill/_dill.py:472: FutureWarning: PytorchModelHandlerKeyedTensor is experimental. No backwards-compatibility guarantees.\n",
            "  obj = StockUnpickler.load(self)\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Input: This is the worst food I have ever eaten -> negative=99.9777%/positive=0.0223%\n",
            "Input: In my soul and in my heart, I’m convinced I’m wrong! -> negative=1.6313%/positive=98.3687%\n",
            "Input: Be with me always—take any form—drive me mad! only do not leave me in this abyss, where I cannot find you! -> negative=62.1188%/positive=37.8812%\n",
            "Input: Do I want to live? Would you like to live with your soul in the grave? -> negative=73.6841%/positive=26.3159%\n",
            "Input: Honest people don’t hide their deeds. -> negative=0.2377%/positive=99.7623%\n",
            "Input: Nelly, I am Heathcliff!  He’s always, always in my mind: not as a pleasure, any more than I am always a pleasure to myself, but as my own being. -> negative=0.0672%/positive=99.9328%\n"
          ]
        }
      ],
      "source": [
        "model_handler = PytorchModelHandlerKeyedTensor(\n",
        "    state_dict_path=\"./distilbert-base-uncased-finetuned-sst-2-english/pytorch_model.bin\",\n",
        "    model_class=HuggingFaceStripBatchingWrapper,\n",
        "    model_params={\"config\": DistilBertConfig.from_pretrained(\"./distilbert-base-uncased-finetuned-sst-2-english/config.json\")},\n",
        "    device='cuda:0')\n",
        "\n",
        "keyed_model_handler = KeyedModelHandler(model_handler)\n",
        "\n",
        "with beam.Pipeline() as pipeline:\n",
        "  _ = (pipeline | \"Create inputs\" >> beam.Create(inputs)\n",
        "                | \"Tokenize\" >> beam.ParDo(Tokenize(\"distilbert-base-uncased-finetuned-sst-2-english\"))\n",
        "                | \"Inference\" >> RunInference(model_handler=keyed_model_handler)\n",
        "                | \"Postprocess\" >> beam.ParDo(PostProcessor())\n",
        "                | \"Print\" >> beam.Map(lambda x: print(f\"Input: {x['input']} -> negative={100 * x['softmax'][0]:.4f}%/positive={100 * x['softmax'][1]:.4f}%\"))\n",
        "  )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7KXeaQg3eCcp"
      },
      "source": [
        "## RunInference with a TensorFlow model\n",
        "\n",
        "This section demonstrates how to use the RunInference API with a TensorFlow model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hEHxNka4eOhC"
      },
      "source": [
        "Note: TensorFlow models are supported through the `tfx-bsl` library."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8KyXULYbeYlD"
      },
      "source": [
        "### Install dependencies\n",
        "\n",
        "First, download and install the dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "uqWJhQBlc4oT",
        "outputId": "2a17a966-fe2d-45d8-b6b9-02534f40c9a8"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "apache_beam"
                ]
              }
            }
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.colab-display-data+json": {
              "pip_warning": {
                "packages": [
                  "google"
                ]
              }
            }
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "!pip install --upgrade pip\n",
        "!pip install google-api-core==1.32\n",
        "!pip install apache_beam[interactive,gcp]==2.41.0\n",
        "!pip install tensorflow==2.8\n",
        "!pip install tfx_bsl\n",
        "!pip install tensorflow-text==2.8.1"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "642maF_redwC"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_text as text\n",
        "from scipy.special import expit\n",
        "\n",
        "import apache_beam as beam\n",
        "import tfx_bsl\n",
        "from tfx_bsl.public.beam import RunInference\n",
        "from tfx_bsl.public import tfxio\n",
        "from tfx_bsl.public.proto import model_spec_pb2\n",
        "from tfx_bsl.public.tfxio import TFExampleRecord\n",
        "from tensorflow_serving.apis import prediction_log_pb2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "h2JP7zsqerCT"
      },
      "source": [
        "### Install the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ydYQ_5EyfeEM"
      },
      "source": [
        "Download a pretrained binary classifier from Google Cloud Storage. The classifier performs sentiment analysis on an IMDB dataset.\n",
        "This model was trained by following this [TensorFlow text classification tutorial](https://www.tensorflow.org/tutorials/keras/text_classification)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BucRWly0flz8"
      },
      "outputs": [],
      "source": [
        "model_dir = \"gs://apache-beam-testing-ml-examples/imdb_bert\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GZ-Ioc8ZfyIT"
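      },
      "source": [
        "### Install helper functions\n",
        "\n",
        "The pipeline uses two helper classes: `ExampleProcessor` converts raw text into a `tf.Example`, and `PredictionProcessor` formats the RunInference output as the input text and its sigmoid output."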
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pZ0LNtHUfsRq"
      },
      "outputs": [],
      "source": [
        "class ExampleProcessor:\n",
        "  \"\"\"\n",
        "  Process the raw text input to a format suitable for RunInference.\n",
        "  The TensorFlow model handler expects a serialized tf.Example as input.\n",
        "  \"\"\"\n",
        "  def create_example(self, feature):\n",
        "    return tf.train.Example(\n",
        "        features=tf.train.Features(\n",
        "            feature={'x': self.create_feature(feature)}))\n",
        "\n",
        "  def create_feature(self, element):\n",
        "    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[element]))\n",
        "\n",
        "\n",
        "class PredictionProcessor(beam.DoFn):\n",
        "  \"\"\"\n",
        "  Process the RunInference output to return the input text and the sigmoid of the classifier output.\n",
        "  \"\"\"\n",
        "  def process(self, element: prediction_log_pb2.PredictionLog):\n",
        "    predict_log = element.predict_log\n",
        "    input_value = tf.train.Example.FromString(predict_log.request.inputs['text'].string_val[0])\n",
        "    output_value = predict_log.response.outputs\n",
        "    # expit is the logistic sigmoid, which maps the classifier output to a value between 0 and 1.\n",
        "    yield (f\"input is [{input_value.features.feature['x'].bytes_list.value}] output is {expit(output_value['classifier'].float_val)}\")"
      ]
    },
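    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `PredictionProcessor` class applies `scipy.special.expit`, the logistic sigmoid, to the classifier output. As a rough illustration, the following sketch computes the same function with the standard library only. The `sigmoid` helper and the sample values are hypothetical, and the sketch assumes that the classifier output is a single raw score where larger values indicate a more positive review."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "\n",
        "def sigmoid(score: float) -> float:\n",
        "  \"\"\"Logistic sigmoid: maps any real-valued score to a value between 0 and 1.\"\"\"\n",
        "  return 1.0 / (1.0 + math.exp(-score))\n",
        "\n",
        "\n",
        "# Hypothetical raw scores: strongly negative, neutral, and strongly positive.\n",
        "for score in (-5.0, 0.0, 5.0):\n",
        "  print(f'score={score:+.1f} -> sigmoid={sigmoid(score):.4f}')"
      ]
    },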
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PZVwI4BbgaAI"
      },
      "source": [
        "### Prepare the input\n",
        "\n",
        "This section demonstrates how to prepare the input for your model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TOXX1KMKi_mm"
      },
      "outputs": [],
      "source": [
        "inputs = np.array([\n",
        "    b\"this is such an amazing movie\",\n",
        "    b\"The movie was great\",\n",
        "    b\"The movie was okish\",\n",
        "    b\"The movie was terrible\"\n",
        "])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O2Y15WmfgZXQ"
      },
      "outputs": [],
      "source": [
        "input_strings_file = 'input_strings.tfrecord'\n",
        "\n",
        "# RunInference expects a serialized tf.Example as input, so preprocess the input\n",
        "# and write the processed input to a file.\n",
        "# You can also do this preprocessing as a pipeline step by using beam.Map().\n",
        "\n",
        "with tf.io.TFRecordWriter(input_strings_file) as writer:\n",
        "  for i in inputs:\n",
        "    example = ExampleProcessor().create_example(feature=i)\n",
        "    writer.write(example.SerializeToString())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BYkQl_l8gRgo"
      },
      "source": [
        "### Run the pipeline\n",
        "\n",
        "This section demonstrates how to create and run the RunInference pipeline."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "uh5bMhxdgA7Q",
        "outputId": "2a22059f-519c-44f7-e36f-59e09b1cb24a"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:From /usr/local/lib/python3.7/dist-packages/tfx_bsl/beam/run_inference.py:615: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.\n",
            "WARNING:apache_beam.io.tfrecordio:Couldn't find python-snappy so the implementation of _TFRecordUtil._masked_crc32c is not as fast as it could be.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "input is [[b'this is such an amazing movie']] output is [0.99906057]\n",
            "input is [[b'The movie was great']] output is [0.99307914]\n",
            "input is [[b'The movie was okish']] output is [0.03274685]\n",
            "input is [[b'The movie was terrible']] output is [0.00680008]\n"
          ]
        }
      ],
      "source": [
        "saved_model_spec = model_spec_pb2.SavedModelSpec(model_path=model_dir)\n",
        "inference_spec_type = model_spec_pb2.InferenceSpecType(saved_model_spec=saved_model_spec)\n",
        "\n",
        "# A Beam I/O that reads a file of serialized tf.Examples.\n",
        "tfexample_beam_record = TFExampleRecord(file_pattern=input_strings_file)\n",
        "\n",
        "with beam.Pipeline() as pipeline:\n",
        "  _ = (pipeline | \"Create Input PCollection\" >> tfexample_beam_record.RawRecordBeamSource()\n",
        "                | \"Do Inference\" >> RunInference(inference_spec_type)\n",
        "                | \"Post Process\" >> beam.ParDo(PredictionProcessor())\n",
        "                | \"Print\" >> beam.Map(print)\n",
        "  )\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8wBUckzHjGV6"
      },
      "source": [
        "## RunInference with scikit-learn\n",
        "\n",
        "This section demonstrates how to use the RunInference API with scikit-learn."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6ArL_55kjxkO"
      },
      "source": [
        "### Install dependencies\n",
        "\n",
        "First, download and install the dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R4p6Mil0jxSy"
      },
      "outputs": [],
      "source": [
        "!pip install --upgrade pip\n",
        "!pip install google-api-core==1.32\n",
        "!pip install apache_beam[interactive,gcp]==2.41.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_YtRRxh1hLag"
      },
      "outputs": [],
      "source": [
        "import pickle\n",
        "\n",
        "import apache_beam as beam\n",
        "from apache_beam.ml.inference import RunInference\n",
        "from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy, ModelFileType"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-7ABKlZvkFHy"
      },
      "source": [
        "### Install the model\n",
        "\n",
        "To classify movie reviews as either positive or negative, this example uses a pretrained sentiment analysis pipeline that is stored as a pickle file in Google Cloud Storage."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WI_UXluPkRYq"
      },
      "source": [
        "This model was trained by following this [scikit-learn tutorial](https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html#exercise-2-sentiment-analysis-on-movie-reviews)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "model_dir = \"gs://apache-beam-testing-ml-examples/sklearn-text-classification/sklearn_sentiment_analysis_pipeline.pkl\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KL4Cx8s0mBqn"
      },
      "source": [
        "### Run the pipeline\n",
        "\n",
        "This section demonstrates how to create and run the RunInference pipeline."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kyN2Aco8l7SR"
      },
      "outputs": [],
      "source": [
        "inputs = [\n",
        "    \"In my soul and in my heart, I’m convinced I’m wrong!\",\n",
        "    \"Be with me always—take any form—drive me mad! only do not leave me in this abyss, where I cannot find you!\",\n",
        "    \"Do I want to live? Would you like to live with your soul in the grave?\",\n",
        "    \"Honest people don’t hide their deeds.\",\n",
        "    \"Nelly, I am Heathcliff!  He’s always, always in my mind: not as a pleasure, any more than I am always a pleasure to myself, but as my own being.\",\n",
        "]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QnQ6ePcgmEeR",
        "outputId": "b0d4d31a-76c1-49e4-aa5a-8003a95bbb47"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "input: In my soul and in my heart, I’m convinced I’m wrong! -> negative\n",
            "input: Be with me always—take any form—drive me mad! only do not leave me in this abyss, where I cannot find you! -> positive\n",
            "input: Do I want to live? Would you like to live with your soul in the grave? -> positive\n",
            "input: Honest people don’t hide their deeds. -> negative\n",
            "input: Nelly, I am Heathcliff!  He’s always, always in my mind: not as a pleasure, any more than I am always a pleasure to myself, but as my own being. -> negative\n"
          ]
        }
      ],
      "source": [
        "# Choose an sklearn model handler based on the input data type:\n",
        "# 1. SklearnModelHandlerNumpy: For using numpy arrays as input.\n",
        "# 2. SklearnModelHandlerPandas: For using pandas dataframes as input.\n",
        "\n",
        "# The sklearn model handler supports loading two serialized formats:\n",
        "# 1. ModelFileType.PICKLE: For models saved using pickle.\n",
        "# 2. ModelFileType.JOBLIB: For models saved using Joblib.\n",
        "\n",
        "model_handler = SklearnModelHandlerNumpy(model_uri=model_dir, model_file_type=ModelFileType.PICKLE)\n",
        "\n",
        "with beam.Pipeline() as pipeline:\n",
        "  _ = (pipeline | \"Create inputs\" >> beam.Create(inputs)\n",
        "                | \"Inference\" >> RunInference(model_handler=model_handler)\n",
        "                | \"Print\" >> beam.Map(lambda x: print(f\"input: {x.example} -> {'positive' if x.inference == 0 else 'negative'}\"))\n",
        "  )"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "5kkjbcIzZIf6",
        "vA1UmbFRb5C-",
        "-7ABKlZvkFHy"
      ],
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
